import os
import sys
import math
import torch
import os.path as osp
import torchvision.utils as tvu

sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
from tqdm import tqdm

# Public names exported by ``from <this module> import *``.
__all__ = ['ContinuousTDiffusion', 'beta_schedule', 'interpolate_fn', 'TimestepMapping']

def _i(tensor, t, x):
    r"""Index tensor using t and format the output according to x.
    """
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    return tensor[t.cpu()].view(shape).to(x)

class TimestepMapping:
    r"""Map (possibly fractional) timesteps to the value that is fed into the model.

    Both mappings operate on the normalized time ``u = (t + 1) / num_timesteps``,
    so ``t`` in ``[0, T - 1]`` corresponds to ``u`` in ``[1 / T, 1]``:

    * ``'inverse_sigmoid'``: the logit ``log(u / (1 - u))``, clamped above at
      ``1e5`` (``u == 1`` would otherwise give ``+inf``);
    * ``'1-t'``: the reciprocal ``1 / u = T / (t + 1)``, which lies in ``[1, T]``.

    ``sample_method`` selects how :meth:`sample` draws training timesteps:
    ``'uniform_t'`` (uniform in ``t``) or ``'uniform_out'`` (uniform over the
    mapped output range, then mapped back with :meth:`inverse`).
    """

    def __init__(self, method, num_timesteps=1000, sample_method='uniform_t'):
        assert method in ('inverse_sigmoid', '1-t')
        assert sample_method in ('uniform_t', 'uniform_out')
        self.method = method
        self.num_timesteps = num_timesteps
        self.sample_method = sample_method

    def __call__(self, t):
        u = (t + 1) / self.num_timesteps  # in [1/T, 1] for t in [0, T - 1]
        if self.method == '1-t':
            return 1 / u  # in [1, T]
        if self.method == 'inverse_sigmoid':
            return torch.log((u / (1 - u)).abs()).clamp(max=1e5)

    def sample(self, batchsize, device):
        r"""Draw ``batchsize`` fractional timesteps in ``[0, T - 1]`` on ``device``."""
        if self.sample_method == 'uniform_out':
            draw = torch.rand((batchsize, ), device=device)
            low, high = sorted(self.remap_t_range())
            t = self.inverse((high - low) * draw + low)
        elif self.sample_method == 'uniform_t':
            t = (self.num_timesteps - 1) * torch.rand((batchsize, ), device=device)
        return t

    def inverse(self, out):
        r"""Map model inputs ``out`` back to timesteps (inverts ``__call__`` up to clamping)."""
        T = self.num_timesteps
        if self.method == '1-t':
            # out in [1, T]  ->  t in [0, T - 1]
            return T / out - 1
        if self.method == 'inverse_sigmoid':
            return (T * torch.nn.functional.sigmoid(out) - 1).clamp(min=0)

    def remap_t_range(self):
        r"""Return the mapped values ``(at t = 0, at t = T - 1)``.

        NOTE(review): for ``'inverse_sigmoid'`` the upper end is the 1e5 clamp
        value, so a uniform grid over this range puts almost every point where
        ``inverse`` saturates at ``T - 1`` -- confirm this is intended.
        """
        T = self.num_timesteps
        if self.method == '1-t':
            return T, 1
        if self.method == 'inverse_sigmoid':
            u_min = 1 / T
            # the upper end matches the clamp applied in __call__
            return math.log(u_min / (1 - u_min)), 1e5


def beta_schedule(schedule, num_timesteps=1000, init_beta=None, last_beta=None):
    r"""Build a discrete noise schedule ``betas`` of length ``num_timesteps``.

    Args:
        schedule: ``'linear'``, ``'quadratic'``, ``'cosine'`` or ``'cosine_shift'``.
        num_timesteps: number of discrete diffusion steps.
        init_beta: first beta for ``'linear'``/``'quadratic'``; ``None`` selects
            the schedule default. An explicit value (including ``0.0``) is used
            as given. Ignored by the cosine schedules.
        last_beta: last beta, same rules as ``init_beta``.

    Returns:
        A float64 tensor of shape ``(num_timesteps,)``. The cosine schedules
        cap every beta at 0.999.

    Raises:
        ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == 'linear':
        # Defaults are 1e-4 / 0.02 at T=1000, scaled by 1000 / num_timesteps.
        scale = 1000.0 / num_timesteps
        init_beta = scale * 0.0001 if init_beta is None else init_beta
        last_beta = scale * 0.02 if last_beta is None else last_beta
        return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
    if schedule == 'quadratic':
        init_beta = 0.0015 if init_beta is None else init_beta
        last_beta = 0.0195 if last_beta is None else last_beta
        return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2

    # The cosine variants are defined through a continuous alpha_bar(u), u in [0, 1];
    # beta_k = 1 - alpha_bar((k + 1) / T) / alpha_bar(k / T), capped at 0.999.
    if schedule == 'cosine':
        def alpha_bar(u):
            return math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2
    elif schedule == "cosine_shift":
        def alpha_bar(u):
            # sigmoid of the cosine log-SNR shifted by 2 * log(1/4); evaluated in
            # float64 (plain Python floats) to match the returned dtype.
            log_snr = -2 * math.log(math.tan((u + 0.008) / 1.008 * math.pi / 2)) + 2 * math.log(1 / 4)
            return 1.0 / (1.0 + math.exp(-log_snr))
    else:
        raise ValueError(f'Unsupported schedule: {schedule}')

    betas = [
        min(1.0 - alpha_bar((step + 1) / num_timesteps) / alpha_bar(step / num_timesteps), 0.999)
        for step in range(num_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float64)

# To stay consistent with discrete-time diffusion (e.g. DDPM), the continuous-time schedule is obtained by piecewise-linear interpolation of the discrete schedule.
def interpolate_fn(x, xp, yp):
    r"""Batched piecewise-linear interpolation of knots ``(xp, yp)`` at query points ``x``.

    Args:
        x: query points, shape ``[N, C]``.
        xp: knot positions, shape ``[C, K]`` with ``K >= 2``. Each row is
            expected to be sorted ascending: values are gathered from ``yp``
            by position in the sorted sequence.
        yp: knot values, shape ``[C, K]``.

    Returns:
        Interpolated values, shape ``[N, C]``. Queries below the first knot
        are linearly extrapolated from the first segment, and queries above
        the last knot from the last segment.
    """
    N, K = x.shape[0], xp.shape[1]
    # Merge each query with the knots: row layout [query, xp_0, ..., xp_{K-1}] -> [N, C, K + 1].
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    # The query had original index 0, so argmin over the sort permutation gives its sorted position.
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    # Left end of the segment, as an index into the merged sorted row:
    # below all knots -> 1 (first knot), above all knots -> K - 2, otherwise the knot just left of the query.
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    # Interior queries sit between the two knots in the merged row, so skip over them (+2).
    # Extrapolated queries do not (+1).
    end_idx = torch.where(
        torch.eq(start_idx, cand_start_idx),
        start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    # The same segment expressed as an index into yp, which has no query entry.
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(
        y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
    ).squeeze(2)
    end_y = torch.gather(
        y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
    ).squeeze(2)
    # Linear interpolation (or extrapolation) along the selected segment.
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand

class ContinuousTDiffusion(object):

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='fixed_small',
                 loss_type='mse',
                 rescale_timesteps=False,
                 timestep_map=None):
        r"""Precompute diffusion coefficients from a discrete ``betas`` schedule.

        Args:
            betas: per-step noise levels, all in ``(0, 1]``. Converted to a
                float64 tensor unless it already is a ``torch.DoubleTensor``.
            mean_type: model parameterization, ``'eps'`` or ``'v'``.
            var_type: ``'fixed_large'`` or ``'fixed_small'``. Defaults to
                ``'fixed_small'``; learned variances (e.g. ``'learned_range'``)
                are rejected below, so they cannot serve as the default.
            loss_type: only ``'mse'`` is accepted.
            rescale_timesteps: if True, timesteps are rescaled by
                ``1000 / num_timesteps`` before reaching the model (see
                ``_scale_timesteps``).
            timestep_map: callable applied to the (scaled) timesteps before
                they are fed to the model; identity when ``None``.

        Raises:
            AssertionError: on out-of-range betas or unsupported options.
        """
        # check input
        if not isinstance(betas, torch.DoubleTensor):
            betas = torch.tensor(betas, dtype=torch.float64)
        assert min(betas) > 0 and max(betas) <= 1
        assert mean_type in ['eps', 'v']
        assert var_type in ['fixed_large', 'fixed_small']
        assert loss_type in ['mse']
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type
        self.var_type = var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps
        self.timestep_map = timestep_map if timestep_map is not None else lambda t: t

        # alphas
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])])

        # q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)

        self.sqrt_recip_sigma = torch.sqrt(1.0 / (1 - self.alphas_cumprod))
        self.sqrt_recipm1_sigma = torch.sqrt(1.0 / (1 - self.alphas_cumprod) - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)

        # Continuous-t support: log(sqrt(alphas_cumprod)) is tabulated on the
        # integer grid ``t_array``. Each coefficient tensor below gets a
        # ``_call`` closure that maps that log-value to the coefficient, and
        # ``_i`` uses these closures (after interpolation) for floating-point t.
        self.t_array = torch.arange(self.num_timesteps, dtype=self.alphas_cumprod.dtype).reshape(1, -1)
        self.log_alphas_array = 0.5 * torch.log(self.alphas_cumprod).reshape(1, -1)
        self.sqrt_alphas_cumprod._call = lambda log_alpha : torch.exp(log_alpha)
        self.sqrt_one_minus_alphas_cumprod._call = lambda log_alpha : torch.sqrt(1. - torch.exp(2. * log_alpha))
        self.alphas_cumprod._call = lambda log_alpha : torch.exp(2 * log_alpha)
        self.sqrt_recip_alphas_cumprod._call = lambda log_alpha : torch.sqrt(1.0 / torch.exp(2. * log_alpha))
        self.sqrt_recipm1_alphas_cumprod._call = lambda log_alpha : torch.sqrt(1.0 / torch.exp(2. * log_alpha) - 1)

    def _i(self, fn, t, x):
        r"""Look up (integer ``t``) or interpolate (floating-point ``t``) a per-timestep coefficient.

        Args:
            fn: per-timestep coefficient tensor. For floating-point ``t`` it must
                carry a ``_call`` attribute (attached in ``__init__``) mapping
                ``log(sqrt(alphas_cumprod))`` to the coefficient.
            t: timesteps with ``x.size(0)`` elements. ``torch.long`` selects
                direct indexing; any other dtype selects interpolation.
            x: reference tensor. The result is shaped ``(x.size(0), 1, ..., 1)``
                and cast to ``x``'s dtype and device.

        Raises:
            ValueError: if neither lookup mode applies to ``(fn, t)``.
        """
        shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
        if isinstance(fn, torch.Tensor) and t.dtype == torch.long:
            # Move the indices to the table's device. The tables are built from
            # ``betas`` (typically on the CPU), while the samplers create ``t`` on
            # the data device, and a CPU tensor cannot be indexed with CUDA indices.
            return fn[t.to(fn.device)].view(shape).to(x)
        if hasattr(fn, '_call'):
            # Interpolate log(sqrt(alphas_cumprod)) at fractional t, then map it
            # to the requested coefficient.
            log_alpha = interpolate_fn(
                t.reshape((-1, 1)),
                self.t_array.to(t.device),
                self.log_alphas_array.to(t.device)
            ).reshape(-1)
            return fn._call(log_alpha).view(shape).to(x)
        raise ValueError(
            f'cannot evaluate coefficient for t of dtype {t.dtype}: integer lookup needs a '
            f'tensor coefficient and torch.long t, continuous lookup needs a coefficient '
            f'with a _call attribute')

    def q_sample(self, x0, t, noise=None):
        r"""Diffuse ``x0`` to step ``t``: ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise``.

        ``t`` may be integer or floating point (coefficients come from ``self._i``).
        Fresh standard-normal noise is drawn when ``noise`` is ``None``.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        signal_scale = self._i(self.sqrt_alphas_cumprod, t, x0)
        noise_scale = self._i(self.sqrt_one_minus_alphas_cumprod, t, x0)
        return signal_scale * x0 + noise_scale * noise

    def q_mean_variance(self, x0, t):
        r"""Return ``(mean, variance, log_variance)`` of ``q(x_t | x_0)``.

        Uses the module-level integer lookup ``_i``, so ``t`` must be an
        integer tensor.
        """
        return (
            _i(self.sqrt_alphas_cumprod, t, x0) * x0,
            _i(1.0 - self.alphas_cumprod, t, x0),
            _i(self.log_one_minus_alphas_cumprod, t, x0),
        )
    
    def q_posterior_mean_variance(self, x0, xt, t):
        r"""Return ``(mean, variance, clipped log_variance)`` of ``q(x_{t-1} | x_t, x_0)``.

        Uses the module-level integer lookup ``_i``, so ``t`` must be an
        integer tensor.
        """
        coef_x0 = _i(self.posterior_mean_coef1, t, xt)
        coef_xt = _i(self.posterior_mean_coef2, t, xt)
        mu = coef_x0 * x0 + coef_xt * xt
        return mu, _i(self.posterior_variance, t, xt), _i(self.posterior_log_variance_clipped, t, xt)

    def v_prediction_groundtruth(self, noise, x0, t):
        r"""Return the v-prediction target ``sqrt(a_bar_t) * noise - sqrt(1 - a_bar_t) * x0``.

        ``t`` may be integer or floating point (coefficients come from ``self._i``).
        """
        sqrt_ab = self._i(self.sqrt_alphas_cumprod, t, x0)
        sqrt_one_minus_ab = self._i(self.sqrt_one_minus_alphas_cumprod, t, x0)
        return sqrt_ab * noise - sqrt_one_minus_ab * x0
    
    @torch.no_grad()
    def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
        r"""Draw one ancestral sample from ``p(x_{t-1} | x_t)``.

        Requires integer ``t``: for floating-point ``t``, ``p_mean_variance``
        returns no variance. If ``condition_fn`` is given, the mean is shifted
        by ``var * condition_fn(xt, scaled t, **model_kwargs)``. No noise is
        added where ``t == 0``.

        Returns:
            ``(x_{t-1}, x0)``, where ``x0`` is the model's clean-sample prediction.
        """
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)

        noise = torch.randn_like(xt)
        broadcast = (-1, ) + (1, ) * (xt.ndim - 1)
        nonzero = t.ne(0).float().view(*broadcast)
        if condition_fn is not None:
            gradient = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            mu = mu.float() + var * gradient.float()
        std = torch.exp(0.5 * log_var)
        return mu + nonzero * std * noise, x0
    
    @torch.no_grad()
    def p_sample_loop(self, noise, model, interval=100, return_inter=False, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
        r"""Ancestral sampling from ``noise`` over every step, ``num_timesteps - 1`` down to ``0``.

        Args:
            interval: an ``x0`` prediction is recorded whenever ``step % interval == 0``
                and at the first (noisiest) step.
            return_inter: also return the list ``[noise, recorded x0 predictions...]``.

        Returns:
            The final sample, or ``(sample, intermediate)`` when ``return_inter``.
        """
        batch = noise.size(0)
        xt = noise
        intermediate = [xt]
        first_step = self.num_timesteps - 1
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((batch, ), step, dtype=torch.long, device=xt.device)
            xt, x0 = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale)
            if step % interval == 0 or step == first_step:
                intermediate.append(x0)
        return (xt, intermediate) if return_inter else xt
    
    def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None):
        r"""Evaluate the model at ``(xt, t)`` and derive reverse-process quantities.

        ``t`` may be ``torch.long`` (discrete) or floating point (continuous). For
        floating-point ``t``, ``var`` and ``log_var`` are ``None``, and so is ``mu``
        for the ``'eps'`` and ``'v'`` mean types: the posterior is only tabulated
        for integer steps. ``x0`` is always returned.

        Args:
            xt: noisy sample at step ``t``, batch on dim 0.
            t: one timestep per batch element.
            model: called as ``model(xt, self.timestep_map(self._scale_timesteps(t)), **kwargs)``.
            model_kwargs: keyword arguments for ``model``. When ``guide_scale`` is
                given, a 2-element list ``[cond_kwargs, uncond_kwargs]``.
            clamp: if not ``None``, ``x0`` is clamped to ``[-clamp, clamp]``.
            percentile: if not ``None`` (must be in ``(0, 1]``), ``x0`` is clipped
                to the per-sample ``percentile`` quantile ``s`` of ``|x0|`` (floored
                at 1) and divided by ``s``. This replaces the final clamp.
            guide_scale: classifier-free guidance weight.

        Returns:
            Tuple ``(mu, var, log_var, x0)``.
        """
        discrete = t.dtype == torch.long
        model_t = self.timestep_map(self._scale_timesteps(t))

        # model output, optionally with classifier-free guidance
        if guide_scale is None:
            out = model(xt, model_t, **model_kwargs)
        else:
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, model_t, **model_kwargs[0])
            u_out = model(xt, model_t, **model_kwargs[1])
            # with learned variances only the first half of the channels is guided
            dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2
            out = torch.cat([
                u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]),
                y_out[:, dim:]], dim=1)

        # variance (integer t only)
        if not discrete:
            log_var, var = None, None
        else:
            if self.var_type == 'learned':
                out, log_var = out.chunk(2, dim=1)
                var = torch.exp(log_var)
            elif self.var_type == 'learned_range':
                out, fraction = out.chunk(2, dim=1)
                min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
                max_log_var = _i(torch.log(self.betas), t, xt)
                fraction = (fraction + 1) / 2.0
                log_var = fraction * max_log_var + (1 - fraction) * min_log_var
                var = torch.exp(log_var)
            elif self.var_type == 'fixed_large':
                var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt)
                log_var = torch.log(var)
            elif self.var_type == 'fixed_small':
                var = _i(self.posterior_variance, t, xt)
                log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
                 _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
        elif self.mean_type == 'x0':
            x0 = out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        elif self.mean_type == 'eps':
            x0 = self._i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                 self._i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
            if clamp is not None:
                x0 = x0.clamp(-clamp, clamp)
            if discrete:
                mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
            else:
                mu = None
        elif self.mean_type == 'v':
            # The module-level ``_i`` only indexes integer t; ``self._i``
            # interpolates the coefficients for continuous t.
            coef = _i if discrete else self._i
            x0 = coef(self.sqrt_alphas_cumprod, t, xt) * xt - \
                 coef(self.sqrt_one_minus_alphas_cumprod, t, xt) * out
            if clamp is not None:
                x0 = x0.clamp(-clamp, clamp)
            # the posterior q(x_{t-1} | x_t, x0) is only tabulated for integer t
            mu = self.q_posterior_mean_variance(x0, xt, t)[0] if discrete else None

        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0)
            # broadcast the per-sample threshold over every non-batch dim (any rank)
            s = s.view(-1, *((1, ) * (x0.ndim - 1)))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        return mu, var, log_var, x0

    @torch.no_grad()
    def ddim_sample(self, xt, t, t_pre, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
        r"""One DDIM update of ``xt`` from timestep ``t`` to ``t_pre``.

        ``t`` and ``t_pre`` may be integer or floating point; ``t_pre`` is
        clamped to be non-negative. With ``condition_fn``, the implied noise
        is shifted by ``-sqrt(1 - a_bar_t) * condition_fn(...)`` and ``x0`` is
        recomputed from it. ``eta`` scales the stochastic term. ``ddim_timesteps``
        is accepted but not used here.

        Returns:
            ``(x_{t_pre}, x0)``.
        """
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
        recip = self._i(self.sqrt_recip_alphas_cumprod, t, xt)
        recipm1 = self._i(self.sqrt_recipm1_alphas_cumprod, t, xt)

        if condition_fn is not None:
            alpha = self._i(self.alphas_cumprod, t, xt)
            guided_eps = (recip * xt - x0) / recipm1
            guided_eps = guided_eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            x0 = recip * xt - recipm1 * guided_eps

        # noise implied by the (possibly guided) x0 prediction
        eps = (recip * xt - x0) / recipm1
        alphas = self._i(self.alphas_cumprod, t, xt)
        alphas_prev = self._i(self.alphas_cumprod, t_pre.clamp(0), xt)
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
        nonzero = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        return torch.sqrt(alphas_prev) * x0 + direction + nonzero * sigmas * noise, x0
    
    @torch.no_grad()
    def ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, scondt=-1, econdt=1000, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
        r"""DDIM sampling with integer timesteps over a piecewise schedule.

        The visited timesteps are each shifted by +1 and clamped to
        ``[0, num_timesteps - 1]``. From noisiest to cleanest they are:

        * every step in ``[econdt, num_timesteps)``, when ``econdt < num_timesteps - 1``;
        * every ``num_timesteps // ddim_timesteps``-th step in ``[scondt, econdt)``;
        * every step in ``[0, scondt)``, when ``scondt > 1``.

        ``econdt`` is capped at ``num_timesteps``. For a 1000-step schedule this
        is identical to the uncapped behavior.

        Returns:
            The final sample ``xt``.
        """
        b = noise.size(0)
        xt = noise
        T = self.num_timesteps
        # Cap econdt so the default (1000) and the dense-tail threshold follow the
        # actual schedule length instead of assuming T == 1000. Otherwise shorter
        # schedules get many duplicate T - 1 steps after clamping.
        econdt = min(econdt, T)

        def _visit(start, stop, stride):
            # timesteps start+1, start+1+stride, ... clamped to [0, T - 1], noisiest first
            return (1 + torch.arange(start, stop, stride)).clamp(0, T - 1).flip(0)

        segments = []
        if econdt < T - 1:
            segments.append(_visit(econdt, T, 1))
        segments.append(_visit(scondt, econdt, T // ddim_timesteps))
        if scondt > 1:
            segments.append(_visit(0, scondt, 1))
        steps = torch.cat(segments, dim=0)
        pre_steps = torch.cat([steps[1:], torch.zeros_like(steps)[:1]], dim=0)

        for pre_step, step in zip(pre_steps, steps):
            t_pre = torch.full((b, ), pre_step, dtype=torch.long, device=xt.device)
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_sample(xt, t, t_pre, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
        return xt

    @torch.no_grad()
    def ddim_remap_t_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=1000, eta=0.0,
                                 progress_bar=False, return_intermediate=False):
        r"""DDIM sampling with continuous timesteps spaced evenly in the model-input space.

        Requires ``self.timestep_map`` to be a ``TimestepMapping``. The procedure is:

        1. Space ``ddim_timesteps`` points evenly over ``timestep_map.remap_t_range()``.
        2. Map them back with ``timestep_map.inverse``.
        3. Clamp them to ``[0, num_timesteps - 1]``.
        4. Visit them from the largest to the smallest.

        Returns:
            ``(x0, intermediate)``. ``x0`` is the last clean-sample prediction.
            ``intermediate`` holds per-step ``'x0'`` and ``'xt'`` lists, which
            are filled only when ``return_intermediate`` is set.
        """
        assert isinstance(self.timestep_map, TimestepMapping)
        batch = noise.size(0)
        xt = noise
        intermediate = {'x0': [], 'xt': []}

        # NOTE(review): with 'inverse_sigmoid' the range ends at 1e5, so most grid
        # points map to t = T - 1 after inverse/clamp -- confirm this is intended.
        grid = torch.linspace(*self.timestep_map.remap_t_range(), ddim_timesteps)
        steps = self.timestep_map.inverse(grid).clamp(0, self.num_timesteps - 1).flip(0)
        pre_steps = torch.cat([steps[1:], torch.zeros_like(steps)[:1]], dim=0)

        pairs = tqdm(zip(pre_steps, steps), total=len(steps), desc='DDIM-remap_t', disable=not progress_bar)
        for pre_step, step in pairs:
            # float fill values -> floating-point t, i.e. the continuous code path
            t_pre = torch.full((batch, ), pre_step, device=xt.device)
            t = torch.full((batch, ), step, device=xt.device)
            xt, x0 = self.ddim_sample(xt, t, t_pre, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
            if return_intermediate:
                intermediate['x0'].append(x0)
                intermediate['xt'].append(xt)
        return x0, intermediate

    @torch.no_grad()
    def gt_ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
        r"""Strided DDIM sampling with integer timesteps, returning the full trajectory.

        Visits ``1 + k * (num_timesteps // ddim_timesteps)``, clamped to
        ``num_timesteps - 1``, from the largest to the smallest.

        Returns:
            List ``[noise, x after step 1, ..., final x]``.
        """
        batch = noise.size(0)
        xt = noise
        T = self.num_timesteps
        stride = T // ddim_timesteps

        steps = (1 + torch.arange(0, T, stride)).clamp(0, T - 1).flip(0)
        pre_steps = torch.cat([steps[1:], torch.zeros_like(steps)[:1]], dim=0)

        trajectory = [noise]
        for pre_step, step in zip(pre_steps, steps):
            t_pre = torch.full((batch, ), pre_step, dtype=torch.long, device=xt.device)
            t = torch.full((batch, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_sample(xt, t, t_pre, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
            trajectory.append(xt)
        return trajectory

    def loss(self, x0, t, model, model_kwargs={}, noise=None):
        r"""Training loss for clean samples ``x0`` at timesteps ``t``.

        ``x0`` is noised with ``q_sample``. The model output is then regressed
        onto ``noise`` (``mean_type='eps'``) or the v target (``'v'``).

        Returns:
            For the mse/l1 variants, the per-sample loss (mean over non-batch
            dims). For ``'mse_sqrt'``, the scalar ``mse_loss``.

        The KL variants and learned variances call ``self.variational_lower_bound``.
        NOTE(review): that method is not in the visible part of this class, so
        confirm it exists. ``__init__`` currently accepts only ``loss_type='mse'``.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, noise=noise)
        kind = self.loss_type

        if kind in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs)
            if kind == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif kind in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
            out = model(xt, self.timestep_map(self._scale_timesteps(t)), **model_kwargs)

            # variational-bound term, only for learned variances
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                # detached mean: the VLB gradient reaches only the variance channels
                frozen = torch.cat([out.detach(), var], dim=1)
                loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
                if kind.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # regression target for the mean parameterization
            if self.mean_type == 'eps':
                target = noise
            elif self.mean_type == 'v':
                target = self.v_prediction_groundtruth(noise, x0, t)
            else:
                raise NotImplementedError()
            power = 1 if kind.endswith('l1') else 2
            loss = (out - target).pow(power).abs().flatten(1).mean(dim=1) + loss_vlb
        elif kind in ['mse_sqrt']:
            out = model(xt, self.timestep_map(self._scale_timesteps(t)), **model_kwargs)
            candidates = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0],
                'v': self.v_prediction_groundtruth(noise, x0, t),
            }
            target = candidates[self.mean_type]
            loss = torch.nn.functional.mse_loss(target, out)
        return loss

    def _scale_timesteps(self, t):
        r"""Rescale ``t`` to a 1000-step range when ``rescale_timesteps`` is set.

        Returns ``t.float() * 1000 / num_timesteps`` in that case, and ``t``
        itself (the same object) otherwise.
        """
        if not self.rescale_timesteps:
            return t
        return t.float() * 1000.0 / self.num_timesteps